package main

import (
	"flag"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net"
	"os"
	"path"
	"path/filepath"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/pkg/sftp"
	"golang.org/x/crypto/ssh"
	"gopkg.in/yaml.v2"
)

// Shared pipeline state.
//
// ipAddrs is an unbuffered channel: OpencfgFile sends one
// "<roomname> <ip entry>" string per selected host and closes it when done;
// Opsee ranges over it. wg counts the per-host goroutines Opsee starts.
var (
	ipAddrs = make(chan string)
	wg      sync.WaitGroup
)

// HostInfo is one element of the "groups" list in roomname.yaml.
//
// Roomname (yaml "roomname") names the group. OpencfgFile compares it with
// os.Args[1]; the value "all" selects every group.
//
// IpList (yaml "ip") holds the group's host entries. OpencfgFile prepends the
// roomname, and Opsee then splits the result on spaces, so each entry is
// expected to look like "ip user passwd".
// NOTE(review): that means SSH passwords are stored in plain text in the YAML.
type HostInfo struct {
	Roomname string   `yaml:"roomname"`
	IpList   []string `yaml:"ip"`
}

// Group is the top-level document of roomname.yaml: a "groups" list of
// HostInfo entries. Illustrative shape (the entry format is inferred from how
// Opsee splits entries, so confirm it against real config files):
//
//	groups:
//	  - roomname: room1
//	    ip:
//	      - "10.0.0.1 root secret"
type Group struct {
	List []*HostInfo `yaml:"groups"`
}

// OpencfgFile reads roomname.yaml from the working directory and feeds the
// selected host entries into ipAddrs as "<roomname> <ip entry>" strings. It
// closes the channel once every entry has been sent.
//
// os.Args[1] picks the groups. "all" sends every group; any other value sends
// only the groups whose Roomname equals it exactly. A read or parse failure
// terminates the program via log.Fatal.
func OpencfgFile() {
	raw, err := ioutil.ReadFile("roomname.yaml")
	if err != nil {
		log.Fatal(err)
	}

	g := new(Group)
	err = yaml.Unmarshal(raw, &g)
	if err != nil {
		log.Fatal(err)
	}

	for _, h := range g.List {
		// Skip groups that are neither selected by name nor covered by "all".
		if want := os.Args[1]; want != "all" && h.Roomname != want {
			continue
		}
		for _, entry := range h.IpList {
			ipAddrs <- h.Roomname + " " + entry
		}
	}
	close(ipAddrs)
}

// Opsee consumes host entries of the form "room ip user passwd" from ipAddrs
// until OpencfgFile closes the channel, and starts one goroutine per host.
// When os.Args[2] is "scp" it starts SftpRun; otherwise it starts Runshell
// with os.Args[2] as the command. It blocks until every goroutine it started
// has called wg.Done, then prints a completion line.
//
// An entry with fewer than four whitespace-separated fields is logged and
// skipped. Indexing such an entry would otherwise panic and kill the batch.
func Opsee() {
	comm := os.Args[2]

	for s := range ipAddrs {
		// Fields (not Split) tolerates repeated spaces and tabs in the config.
		fields := strings.Fields(s)
		if len(fields) < 4 {
			// Don't echo the raw entry: it may contain a password.
			log.Printf("skipping malformed host entry with %d field(s); want \"room ip user passwd\"", len(fields))
			continue
		}
		room, ip, user, passwd := fields[0], fields[1], fields[2], fields[3]

		wg.Add(1)
		if comm == "scp" {
			go SftpRun(room, user, passwd, ip)
		} else {
			go Runshell(room, user, passwd, ip, comm)
		}
	}

	wg.Wait()
	// The channel is closed and drained once the range loop exits, so the
	// original len(ipAddrs) == 0 check was always true.
	fmt.Println("End of channel data reading")
}

// sshSession connects to host:port with password authentication (through
// connector) and opens a new session on that connection.
//
// On success the caller owns the returned session and must Close it.
// NOTE(review): the underlying *ssh.Client is not returned, so nothing closes
// it after the session is done. Fixing that requires changing the signature.
// If no session can be opened, the client is closed before the error is
// returned so the connection is not leaked.
func sshSession(user, password, host string, port int) (*ssh.Session, error) {
	client, err := connector(user, password, host, port)
	if err != nil {
		// The caller reports connection failures together with room/ip.
		return nil, err
	}

	session, err := client.NewSession()
	if err != nil {
		// Message reads "failed to create client".
		fmt.Println("创建客户端失败", err)
		client.Close()
		return nil, err
	}
	return session, nil
}

// connector opens an SSH connection to host:port as user, authenticating
// with password. The dial timeout is 1 second. sftpconnect uses 30 seconds.
//
// SECURITY: the server's host key is not verified, so the connection can be
// intercepted by a man-in-the-middle. This behavior is kept from the original.
//
// The address is built with net.JoinHostPort so IPv6 literals such as "::1"
// become "[::1]:22" instead of the unparsable "::1:22".
func connector(user, password, host string, port int) (*ssh.Client, error) {
	clientConfig := &ssh.ClientConfig{
		User:            user,
		Auth:            []ssh.AuthMethod{ssh.Password(password)},
		Timeout:         1 * time.Second,
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
	}

	addr := net.JoinHostPort(host, strconv.Itoa(port))
	return ssh.Dial("tcp", addr, clientConfig)
}
// sftpconnect opens an SSH connection to host:port as user, authenticating
// with password and a 30 second dial timeout, then starts an SFTP client on
// that connection.
//
// SECURITY: the server's host key is not verified (kept from the original).
//
// If the SFTP client cannot be started, the SSH connection is closed before
// returning. NOTE(review): on success the *ssh.Client is not exposed, so the
// caller can only Close the returned sftp.Client.
func sftpconnect(user, password, host string, port int) (*sftp.Client, error) {
	clientConfig := &ssh.ClientConfig{
		User:            user,
		Auth:            []ssh.AuthMethod{ssh.Password(password)},
		Timeout:         30 * time.Second,
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
	}

	// JoinHostPort brackets IPv6 literals; plain concatenation does not.
	addr := net.JoinHostPort(host, strconv.Itoa(port))

	sshClient, err := ssh.Dial("tcp", addr, clientConfig)
	if err != nil {
		return nil, err
	}

	sftpClient, err := sftp.NewClient(sshClient)
	if err != nil {
		sshClient.Close()
		return nil, err
	}
	return sftpClient, nil
}
// SftpRun uploads the local file os.Args[3] into the remote directory
// os.Args[4] on ip:22 over SFTP, keeping the file's base name. On success it
// prints a line tagged with room and ip. It is meant to run as one of Opsee's
// goroutines and always calls wg.Done when it returns.
//
// Per-host failures (connect, remote create, copy or close) are logged and
// the function returns, so one bad host does not abort the whole batch. A
// local file that cannot be opened stays fatal, because it would fail the
// same way for every host.
func SftpRun(room, user, passwd, ip string) {
	defer wg.Done()

	sftpClient, err := sftpconnect(user, passwd, ip, 22)
	if err != nil {
		// A nil client would panic below; report and skip this host.
		log.Println(room+"."+ip+"\n"+"sftp connect failed:", err)
		return
	}
	defer sftpClient.Close()

	localFilePath := os.Args[3]
	remoteDir := os.Args[4]

	srcFile, err := os.Open(localFilePath)
	if err != nil {
		log.Fatal(err)
	}
	defer srcFile.Close()

	// The local path uses OS separators (filepath); the SFTP remote path is
	// slash-separated (path).
	remotePath := path.Join(remoteDir, filepath.Base(localFilePath))
	dstFile, err := sftpClient.Create(remotePath)
	if err != nil {
		log.Println(room+"."+ip+"\n"+"create "+remotePath+" failed:", err)
		return
	}

	_, err = io.Copy(dstFile, srcFile)
	// Close explicitly (not deferred) so its error is checked too.
	if cerr := dstFile.Close(); err == nil {
		err = cerr
	}
	if err != nil {
		log.Println(room+"."+ip+"\n"+"copy file failed:", err)
		return
	}

	fmt.Println(room + "." + ip + "\n" + "copy file  finished!")
}
// Runshell runs comm on ip:22 as user/passwd and prints the combined
// stdout and stderr, tagged with room and ip. It is meant to run as one of
// Opsee's goroutines and always calls wg.Done when it returns.
//
// Connection failures are logged with the room/ip tag. If the command itself
// fails (for example a non-zero exit status), the captured output is still
// printed and the error is logged after it.
func Runshell(room, user, passwd, ip, comm string) {
	defer wg.Done()

	session, err := sshSession(user, passwd, ip, 22)
	if err != nil {
		// Println, not Printf: the tag is data, not a format string. The
		// original Printf call produced "%!(EXTRA ...)" output.
		log.Println(room+"."+ip+"\n"+"result:", err)
		return
	}
	defer session.Close()

	buf, err := session.CombinedOutput(comm)
	fmt.Println(room + "." + ip + "\n" + "result:" + string(buf))
	if err != nil {
		log.Println(room+"."+ip+"\n"+"command error:", err)
	}
}
// main checks the shape of the command line and then runs the pipeline:
// OpencfgFile produces host entries in the background while Opsee consumes
// them. Two shapes are accepted:
//
//	gosh <group> <command>                    run <command> on every host
//	gosh <group> scp <localfile> <remotedir>  upload <localfile> to every host
//
// Any other shape prints the usage text and returns.
func main() {
	flag.Parse()

	// This file defines no flags, so flag.NArg() counts the positional
	// arguments. os.Args[2] is the command, or the literal "scp" for upload.
	shellMode := flag.NArg() == 2 && os.Args[2] != "scp"
	scpMode := flag.NArg() == 4 && os.Args[2] == "scp"

	if !(shellMode || scpMode) {
		// Usage text (Chinese). It says the tool only runs commands in bulk
		// or pushes a file via scp, and gives the argument order and an
		// example for each mode.
		fmt.Println("本go 只有两个功能，批量执行命令和scp下发文件" + "\n" + "正确执行方式，1：hosts组，2：执行的命令;" + "\n" + "例如 ./gosh all \"date\";" + "\n" + "如果下发文件，1：hosts组，2：scp 3：本地文件路径，4：远端目录;" + "\n" + "例如 ./gosh all scp \"/tmp/test.log\" \"/tmp\"")
		return
	}

	go OpencfgFile()
	Opsee()
}
